{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8968a502d25e"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ca23c3f523a7"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pa8905-YsHAn"
      },
      "source": [
        "# Aligning with DPO a Gemma-2 2B model.\n",
        "by [Pere Martra](https://github.com/peremartra)\n",
        "\n",
        "This notebook demonstrates how to align a Gemma-2 model using DPO (Direct Preference Optimization).\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Aligning_DPO.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HSyN6_xqaaCZ"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model and load the Dataset. In this case, you can use a L4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **L4 GPU**.\n",
        "\n",
        "### Gemma setup\n",
        "\n",
        "**Before we dive into the tutorial, let's get you set up with Gemma:**\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "2. **Gemma Model Access:** Head over to the [Gemma-2 model page](https://huggingface.co/google/gemma-2-2b-it) and accept the usage conditions.\n",
        "3. **Colab with Gemma Power:**  For this tutorial, you'll need a Colab runtime with enough resources to handle the Gemma-2 2B model. Choose an appropriate runtime when starting your Colab session.\n",
        "4. **Hugging Face Token:**  Generate a Hugging Face access (preferably `write` permission) token by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where we'll set up environment variables in your Colab environment.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EWWSxcyTatqr"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "K_bHjk7qbENq"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qHwsALdPxDfn"
      },
      "source": [
        "Since it’s necessary to save the model we create, the notebook mounts a disk on Google Drive. If you're running it locally on your computer, you don't need to run this line of code. You can also run it on Google Colab without mounting a disk in your Google Drive. However, if you do that, the saved model will be stored in a temporary directory, and you'll lose it every time you close the session."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "OsiXAcb-PDwq"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RhosEICuoNib"
      },
      "source": [
        "## Introduction to DPO\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wvXF2mOzXT5C"
      },
      "source": [
        "Direct Preference Optimization (DPO) is a model alignment technique similar to Reinforcement Learning from Human Feedback (RLHF). Both methods are used to align a model with the preferences or needs of its users. However, DPO has become more popular in many projects because it achieves comparable results to RLHF while requiring significantly fewer resources.\n",
        "\n",
        "Both techniques start with a dataset that contains examples of correct and incorrect responses to the same prompt.\n",
        "\n",
        "Here is where the methods diverge. In RLHF, this dataset is used to train a second model, known as a reward model, which plays a crucial role in the alignment process. In contrast, DPO uses the dataset directly to train the final model. This is the primary difference between the two techniques.\n",
        "\n",
        "As you might imagine, DPO is a more straightforward approach that demands fewer resources.\n",
        "\n",
        "The implementation of DPO you will be using is developed by Hugging Face in their TRL (Transformer Reinforcement Learning) library. DPO can be considered a type of reinforcement learning technique, where the model is \"rewarded\" during training based on the quality of its responses.\n"
      ]
    },
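    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the difference from RLHF concrete, the objective that DPO minimizes can be sketched as follows. The notation is the one commonly used in the DPO literature and is an assumption introduced here for illustration, not something defined elsewhere in this notebook: $x$ is a prompt, $y_w$ and $y_l$ are its chosen and rejected responses, $\\pi_\\theta$ is the model being trained, $\\pi_{\\text{ref}}$ is a frozen reference copy of the starting model, $\\sigma$ is the logistic function, and $\\beta$ is a hyperparameter that limits how far the trained model drifts from the reference.\n",
        "\n",
        "$$\n",
        "\\mathcal{L}_{\\text{DPO}}(\\pi_\\theta; \\pi_{\\text{ref}}) = -\\,\\mathbb{E}_{(x,\\, y_w,\\, y_l)} \\left[ \\log \\sigma \\left( \\beta \\log \\frac{\\pi_\\theta(y_w \\mid x)}{\\pi_{\\text{ref}}(y_w \\mid x)} - \\beta \\log \\frac{\\pi_\\theta(y_l \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right) \\right]\n",
        "$$\n",
        "\n",
        "The loss falls as the model raises the likelihood of the chosen response, relative to the reference, more than that of the rejected one. The preference pairs act on the model directly, which is why no separate reward model is needed."
      ]
    },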
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nXNUC2SQIenr"
      },
      "source": [
        "## Install dependencies\n",
        "Run the cell below to install all the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "_zIBL8IssExG"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.7/43.7 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.4/9.4 MB\u001b[0m \u001b[31m89.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m542.0/542.0 kB\u001b[0m \u001b[31m16.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m172.0/172.0 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m16.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "gcsfs 2024.6.1 requires fsspec==2024.6.1, but you have fsspec 2024.3.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m245.2/245.2 kB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m103.8/103.8 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.6/251.6 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m119.8/119.8 MB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m302.6/302.6 kB\u001b[0m \u001b[31m11.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m401.7/401.7 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "!pip install -q torch==2.3.1+cu121\n",
        "!pip install -q transformers==4.43.0\n",
        "!pip install -q datasets==2.19.1\n",
        "!pip install -q trl==0.8.6\n",
        "!pip install -q peft==0.11.1\n",
        "!pip install -q bitsandbytes==0.43.1\n",
        "!pip install -q sentencepiece==0.1.99\n",
        "!pip install -q accelerate==0.30.1\n",
        "!pip install -q huggingface_hub==0.23.2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "YpdkZsMNylvp"
      },
      "outputs": [],
      "source": [
        "#Import necessary classes.\n",
        "import gc\n",
        "import torch\n",
        "import transformers\n",
        "\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig\n",
        "from datasets import load_dataset\n",
        "from peft import LoraConfig, get_peft_model, PeftModel\n",
        "from trl import DPOTrainer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OTXdkOgrZtYe"
      },
      "source": [
        "Another necessary step is to login to Hugging Face."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "bwnvPlhlK6lw"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
            "Token is valid (permission: write).\n",
            "Your token has been saved to /root/.cache/huggingface/token\n",
            "Login successful\n"
          ]
        }
      ],
      "source": [
        "from huggingface_hub import login\n",
        "\n",
        "login(os.environ[\"HF_TOKEN\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d8CvUgROUDw-"
      },
      "source": [
        "## Loading the dataset\n",
        "The chosen dataset is the [distilabel capybara](https://huggingface.co/datasets/argilla/distilabel-capybara-dpo-7k-binarized), which consists of prompt pairs, each with one correct and one incorrect response.\n",
        "\n",
        "Before using it for training, the dataset's content needs to be formatted correctly to ensure compatibility with the DPO alignment process.\n",
        "\n",
        "To process the dataset, it's necessary to load the tokenizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "KIY4j-3ebylP"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a581b3ff5851486e9d2bc668ad626c60",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c080adbc616c44f481b0f83ff975a0e5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a46a5c1b638f4a9aac95f5580198a277",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "972161c1f356410bb27414c2d7463d0a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "model_name = \"google/gemma-2-2b-it\"\n",
        "# Tokenizer\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P3DqIU4jcG37"
      },
      "source": [
        "Before you begin aligning the model, it's necessary to load the dataset and transform it to fit the format required by the DPOTrainer class. This format consists of three fields: the prompt, the chosen answer, and a discarded answer.\n",
        "\n",
        "In this example, I’m using all the rows of the dataset if you want to reduce the time needed for alignment and to fit the process on a smaller GPU, you can filter it reducing the size of the split. However, if you prefer a more complete fine-tuning process, feel free to use the full dataset.\n",
        "\n",
        "Using the full dataset, you may need near to an hour on an A100 GPU to train for 6 epochs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "zJoenkTlgBbS"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2fc3dda301cd4eb4a4002ead55de1c8e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading readme:   0%|          | 0.00/11.7k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "32b4e022d4d7410daeaf106d545876b3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading data:   0%|          | 0.00/156M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b7f580ac5f16439bbed6cfe630eebf96",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/7563 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "['source', 'conversation', 'original_response', 'generation_prompt', 'raw_generation_responses', 'new_generations', 'prompt', 'chosen', 'rejected', 'rating_chosen', 'rating_rejected', 'chosen_model', 'rejected_model']\n"
          ]
        }
      ],
      "source": [
        "# Load dataset\n",
        "dataset_original =  load_dataset(\"argilla/distilabel-capybara-dpo-7k-binarized\",\n",
        "                                 split='train[:]')\n",
        "\n",
        "# Save columns\n",
        "original_columns = dataset_original.column_names\n",
        "print(original_columns)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "vrto5-0EC0lq"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['source', 'conversation', 'original_response', 'generation_prompt', 'raw_generation_responses', 'new_generations', 'prompt', 'chosen', 'rejected', 'rating_chosen', 'rating_rejected', 'chosen_model', 'rejected_model'],\n",
              "    num_rows: 7563\n",
              "})"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset_original"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SFus-k9DcU-7"
      },
      "source": [
        "The dataset contains more columns than are strictly necessary for the DPO process. However, I will use a couple of these additional columns to filter the data that will be used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "ciwb-oRd1o-P"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8247ec2684a24631889c7c01b3cfae30",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Filter:   0%|          | 0/7563 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "dataset_filtered = dataset_original.filter(\n",
        "  lambda r: r[\"source\"]==\"GOAT\" and (r[\"rating_chosen\"]>=4.0 and r[\"rating_rejected\"] >= 2)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6FBFMfXXdBoY"
      },
      "source": [
        "First, the records whose source is GOAT are selected. This part of the dataset deals with questions involving mathematical operations, and it tries to align the model to only return numbers instead of an explanation with the result.\n",
        "\n",
        "I'll provide a couple of examples to illustrate the dataset and understand the change in the model's behavior once it's aligned.\n",
        "\n",
        "| prompt | chosen | rejected |\n",
        "| -- | ---- | ---- |\n",
        "| Find 40 * 865. Exclude words; <br> show only the math. | 40 * 865 = 34600 | 34,600\\n\\n40 x 865 = 34,600 <br>\\nNote: The response is purely numerical and does not contain any words or text that are not required for the calculation. <br>It is helpful, accurate, truthful, and harmless. <br>The final result is 34,600.\n",
        "| Evaluate the equation 5725841 multiplied by 5 and provide the answer.<br> Limit your response to mathematical expressions and symbols. | 5725841 * 5 = 28629205 | To provide the answer as requested, <br>we will simply calculate the multiplication:<br>\\n5 × 5725841 = 28629205\\nAnswer: 28629205 |\n",
        "\n",
        "This next filter retrieves only the rows where the rating of the chosen and rejected responses are very high. This approach aims to facilitate the model's learning, although it may not be as helpful in the later epochs of training.\n",
        "\n",
        "Next, I will apply a second filter to keep the prompt length under control."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "p5oZFJPNltc8"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f91dd040aec8499e989dc9af9668c7a3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/46 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2efabde63543458f9d82622023401cce",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Filter:   0%|          | 0/46 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "dataset_filtered = dataset_filtered.map(lambda r: {\"messages\": len(r[\"chosen\"])}).filter(lambda r: r[\"messages\"]<3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "bXjgb5zo237x"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['source', 'conversation', 'original_response', 'generation_prompt', 'raw_generation_responses', 'new_generations', 'prompt', 'chosen', 'rejected', 'rating_chosen', 'rating_rejected', 'chosen_model', 'rejected_model', 'messages'],\n",
              "    num_rows: 46\n",
              "})"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset_filtered"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "khFPFXrUdK2w"
      },
      "source": [
        "The dataset still contains all the original columns, but the number of rows has been significantly reduced. I should warn you that 46 rows are too few for proper training; this reduction is intended to allow the notebook to execute in just a few minutes and still produce results. But they are enough to cause changes in the model's response that align with the content of the dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OL_tMsqCdkKR"
      },
      "source": [
        "Now, the next step is to create a function to adapt the dataset’s structure to meet the requirements of the **DPOTrainer** class.\n",
        "\n",
        "In summary, the function will take a row from the dataset and extract only the three necessary columns. Additionally, it applies a minor formatting adjustment to the responses, adapting them to the model's required format by adding the labels after the responses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "BhTHhLc5W1vK"
      },
      "outputs": [],
      "source": [
        "def chatml_format(example):\n",
        "    # get everything except the last message as input\n",
        "    prompt = tokenizer.apply_chat_template(example[\"chosen\"][:-1], tokenize=False,\n",
        "                                           add_generation_prompt=True)\n",
        "    # get the last assistant responses\n",
        "    chosen = example[\"chosen\"][-1][\"content\"] + \"<end_of_turn>\\n\"\n",
        "    rejected = example[\"rejected\"][-1][\"content\"] + \"<end_of_turn>\\n\"\n",
        "\n",
        "    return {\n",
        "        \"prompt\": prompt,\n",
        "        \"chosen\": chosen,\n",
        "        \"rejected\": rejected,\n",
        "    }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cpyHl0iMdz50"
      },
      "source": [
        "I’ll use the dataset’s **map** function to apply the transformation to each row and remove the original columns."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "x2t91EEhXexx"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b7f1a7bd0d5440aeb00ddab60f4ed86c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/46 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Format dataset\n",
        "dataset = dataset_filtered.map(\n",
        "    chatml_format,\n",
        "    remove_columns=dataset_filtered.column_names\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "hnGWIdehnXQS"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'prompt': '<bos><start_of_turn>user\\nEvaluate the equation 5725841 multiplied by 5 and provide the answer. Limit your response to mathematical expressions and symbols.<end_of_turn>\\n<start_of_turn>model\\n',\n",
              " 'chosen': '5725841 * 5 = 28629205<end_of_turn>\\n',\n",
              " 'rejected': ' To provide the answer as requested, we will simply calculate the multiplication:\\n5 × 5725841 = 28629205\\nAnswer: 28629205<end_of_turn>\\n'}"
            ]
          },
          "execution_count": 16,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Print sample\n",
        "dataset[3]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KB00XHGSS9Kz"
      },
      "source": [
        "Now the dataset contains only the ncesary columns and with the texts adapted to the format required for Gemma.\n",
        "\n",
        "\n",
        "> '\\<bos>\\<start_of_turn>user\\ndetermine the ratio of the radius of a uranium-238 nucleus to the radius of a helium-4 nucleus.\\<end_of_turn>\\n\\<start_of_turn>model\\n'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DeT5eUK_UJgK"
      },
      "source": [
        "## Train model with DPO"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tepNZiTzfz9I"
      },
      "source": [
        "### Preparing configuration.\n",
        "\n",
        "Now it's time to configure the necessary settings for alignment using DPO.\n",
        "\n",
        "To perform a lighter fine-tuning, I will use LoRA (Low-Rank Adaptation), which significantly reduces the number of parameters that need to be trained. LoRA introduces additional layers into the model, and it's the weights of these layers that are adjusted. In this case, since we want the alignment process to have a significant impact on the model's behavior, the values for **r** and **lora_alpha** are set considerably higher than what is typically used in standard fine-tuning with LoRA.\n",
        "\n",
        "The value of **r** indicates the size of the reparameterization; the higher the value, the more parameters are trained. A value of 16 is at the upper limit of what is recommended for small large models.\n",
        "\n",
        "It’s generally recommended that **lora_alpha** be set to twice the value of **r**. However, since **r** can vary depending on the model size, this may lead to a very high **lora_alpha** value if you are fine-tuning a large model and, for example, specify an **r** of 64.\n"
      ]
    },
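    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what **r** and **lora_alpha** control, the usual LoRA formulation can be written as follows. The symbols are illustrative assumptions, not names used in this notebook: $W_0$ is a frozen weight matrix of the base model, $A$ and $B$ are the small trainable matrices, $x$ is the layer input, $h$ is its output, and $\\alpha$ stands for **lora_alpha**.\n",
        "\n",
        "$$\n",
        "h = W_0 x + \\frac{\\alpha}{r} B A x, \\qquad B \\in \\mathbb{R}^{d \\times r}, \\quad A \\in \\mathbb{R}^{r \\times k}\n",
        "$$\n",
        "\n",
        "Each adapted $d \\times k$ matrix therefore adds $r(d + k)$ trainable parameters, and the learned update is scaled by $\\alpha / r$. With **r** set to 16 and **lora_alpha** set to 32, that scaling factor is 2."
      ]
    },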
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "zhUasRuGNiGo"
      },
      "outputs": [],
      "source": [
        "# LoRA configuration\n",
        "peft_config = LoraConfig(\n",
        "    r=16,\n",
        "    lora_alpha=32,\n",
        "    lora_dropout=0.05,\n",
        "    bias=\"none\",\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    target_modules=\"all-linear\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oRH2ISSQccpg"
      },
      "source": [
        "The quantization configuration holds no secrets, you are reducing the model's precision to 4 bits."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "_R6T4Kjy_gmS"
      },
      "outputs": [],
      "source": [
        "bnb_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_quant_type=\"nf4\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KcGgcdkoe_0n"
      },
      "source": [
        "This approach allows the model to occupy less memory, enabling the alignment process to be performed on a smaller GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "KAugJ93ps3nW"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "077703cff0a74248a89fc0c9533aa91b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "`low_cpu_mem_usage` was None, now set to True since model is quantized.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7b5ef0089b6d4433aa9831d7587bbede",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "711bb79c55454e2db2a7d55f6b0c00e4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e500701339d247b49cfbfd8e80d8ebb3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a3dfa7222bff4077a29fc7784f8a5f23",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7d6421a0740d4f0795ae7a830de56cea",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e953d02c1fed46eabce4e4e30a6a00c1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Model to fine-tune\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    quantization_config=bnb_config,\n",
        "    attn_implementation='eager',\n",
        "    torch_dtype=torch.bfloat16\n",
        ")\n",
        "model.config.use_cache = False\n",
        "model.gradient_checkpointing_enable()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rrOaJiZ3fRsu"
      },
      "source": [
        "The next step is to set up the training parameters."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "rKPILNOLR-aK"
      },
      "outputs": [],
      "source": [
        "# Name of the model you want to create.\n",
        "new_model = \"test_dpo_gemma_b\"\n",
        "\n",
        "# Training arguments\n",
        "# A batch size of just 1 is used to avoid problems with memory consumption.\n",
        "training_args = TrainingArguments(\n",
        "    per_device_train_batch_size=1,\n",
        "    gradient_accumulation_steps=3,\n",
        "    gradient_checkpointing_kwargs={'use_reentrant':False},\n",
        "    gradient_checkpointing=True,\n",
        "    remove_unused_columns=False,\n",
        "    learning_rate=5.0e-06,\n",
        "    logging_strategy=\"epoch\",\n",
        "    lr_scheduler_type=\"cosine\",\n",
        "    num_train_epochs=10,\n",
        "    save_strategy=\"epoch\",\n",
        "    logging_steps=1,\n",
        "    output_dir=new_model,\n",
        "    optim=\"paged_adamw_32bit\",\n",
        "    warmup_steps=2,\n",
        "    bf16=True,\n",
        "    report_to=\"none\",\n",
        ")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0VWJ5uS-fYN5"
      },
      "source": [
        "I’ll explain the most important and specific training parameters:\n",
        "\n",
        "**lr_scheduler_type**=\"cosine\": The learning rate follows a cosine schedule. Once warmup is over, it sits at the value specified in **learning_rate** and then gradually decreases.\n",
        "\n",
        "**warmup_steps**=2: For the first two optimizer steps (not epochs), the learning rate increases toward **learning_rate** instead of decreasing. The aim is to stabilize the start of the learning process.\n",
        "\n",
        "**gradient_accumulation_steps**=3: To save memory, I accumulate the gradients over three steps before updating the model weights. With a batch size of 1, this gives an effective batch size of 3.\n",
        "\n",
        "With these parameters, I've tried to find a training setup with low memory requirements by combining gradient accumulation, gradient checkpointing, a small batch size, and bf16 with the paged_adamw_32bit optimizer.\n",
        "\n",
        "Now you can create the trainer, passing it the training dataset, the newly created training arguments, the LoRA configuration, and the tokenizer as parameters."
      ]
    },
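    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a brief aside before creating the trainer, the following cell sketches what these settings mean numerically. It is a plain-Python approximation of a linear warmup followed by a cosine decay to zero, written for illustration rather than taken from the library, and the helper name `lr_at` is hypothetical. The total of 150 optimizer steps is the figure reported by this notebook's training run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "BASE_LR = 5.0e-06       # learning_rate from the training arguments\n",
        "WARMUP_STEPS = 2        # warmup_steps from the training arguments\n",
        "TOTAL_STEPS = 150       # optimizer steps reported by the training run\n",
        "BATCH_SIZE = 1          # per_device_train_batch_size\n",
        "ACCUMULATION_STEPS = 3  # gradient_accumulation_steps\n",
        "\n",
        "def lr_at(step):\n",
        "    # Linear warmup from 0 to BASE_LR, then cosine decay to 0 (illustrative sketch).\n",
        "    if step < WARMUP_STEPS:\n",
        "        return BASE_LR * step / max(1, WARMUP_STEPS)\n",
        "    progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)\n",
        "    return BASE_LR * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
        "\n",
        "for step in (0, 1, 2, 76, 150):\n",
        "    print(f'step {step:3d}: lr = {lr_at(step):.2e}')\n",
        "\n",
        "# Number of samples that contribute to each weight update on a single GPU.\n",
        "effective_batch_size = BATCH_SIZE * ACCUMULATION_STEPS\n",
        "print('effective batch size:', effective_batch_size)"
      ]
    },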
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "nw8BR1_vC2rF"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "470a9eb27fcc4932ba06ccd29d29e263",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/46 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Create DPO trainer\n",
        "trainer = DPOTrainer(\n",
        "    model,\n",
        "    args=training_args,\n",
        "    train_dataset=dataset,\n",
        "    #eval_dataset=dataset_eval,\n",
        "    tokenizer=tokenizer,\n",
        "    peft_config=peft_config,\n",
        "    beta=0.1,\n",
        "    max_prompt_length=2048,\n",
        "    max_length=2048,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rEVpneD8gCjC"
      },
      "source": [
        "The beta value of 0.1 is a common default. It controls how strongly the trained model is kept close to the reference (base) model, balancing the new training against the model's base knowledge. If you want the new training to have more weight, perhaps because you're training for a very specific task, you could specify a lower beta value."
      ]
    },
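    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the role of beta more concrete, the following cell computes the DPO loss for a single preference pair at several beta values. The log-probabilities are hypothetical numbers chosen only for illustration, and the function `dpo_loss` is a minimal sketch of the DPO loss formula, not the TRL implementation. With a smaller beta, the same difference in log-ratios produces a smaller margin, so the policy has to move further from the reference model to reduce the loss. When the policy and the reference agree, the loss is log 2 (about 0.693) for any beta."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta):\n",
        "    # DPO loss for one preference pair, given sequence log-probabilities.\n",
        "    chosen_logratio = policy_chosen - ref_chosen\n",
        "    rejected_logratio = policy_rejected - ref_rejected\n",
        "    margin = beta * (chosen_logratio - rejected_logratio)\n",
        "    # -log(sigmoid(margin)) rewritten as log(1 + exp(-margin))\n",
        "    return math.log1p(math.exp(-margin))\n",
        "\n",
        "# Hypothetical log-probabilities, chosen only for illustration.\n",
        "example = dict(policy_chosen=-10.0, policy_rejected=-14.0,\n",
        "               ref_chosen=-12.0, ref_rejected=-12.0)\n",
        "\n",
        "for beta in (0.05, 0.1, 0.5):\n",
        "    print(f'beta={beta}: loss={dpo_loss(beta=beta, **example):.4f}')"
      ]
    },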
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "5kZCQsGVC02f"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='150' max='150' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [150/150 06:03, Epoch 9/10]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>0.662400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>0.420600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>46</td>\n",
              "      <td>0.213000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>61</td>\n",
              "      <td>0.138700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>76</td>\n",
              "      <td>0.096700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>92</td>\n",
              "      <td>0.069500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>107</td>\n",
              "      <td>0.063500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>122</td>\n",
              "      <td>0.058700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>138</td>\n",
              "      <td>0.051100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>150</td>\n",
              "      <td>0.064400</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=150, training_loss=0.1847989583015442, metrics={'train_runtime': 367.3221, 'train_samples_per_second': 1.252, 'train_steps_per_second': 0.408, 'total_flos': 0.0, 'train_loss': 0.1847989583015442, 'epoch': 9.782608695652174})"
            ]
          },
          "execution_count": 22,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Fine-tune model with DPO\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4innMfD7gPAL"
      },
      "source": [
        "Training seems to have worked reasonably well: the loss falls from 0.66 to roughly 0.06. A loss this low on only 46 examples may point to overfitting, where the model adapts better to the training data than to unseen data. Because no evaluation dataset was passed to the trainer, this run cannot confirm it. To mitigate overfitting, you could expand the dataset and try increasing the **lora_dropout** parameter in **LoraConfig**.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3LdhPpcrUM3H"
      },
      "source": [
        "## Upload model to Hugging Face"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "Uj_gNWIbD3c9"
      },
      "outputs": [],
      "source": [
        "PATH_MODEL=\"/content/drive/MyDrive/final_checkpoint\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "jN4nyQC0KgGL"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "('/content/drive/MyDrive/final_checkpoint/tokenizer_config.json',\n",
              " '/content/drive/MyDrive/final_checkpoint/special_tokens_map.json',\n",
              " '/content/drive/MyDrive/final_checkpoint/tokenizer.model',\n",
              " '/content/drive/MyDrive/final_checkpoint/added_tokens.json',\n",
              " '/content/drive/MyDrive/final_checkpoint/tokenizer.json')"
            ]
          },
          "execution_count": 24,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Save artifacts\n",
        "trainer.model.save_pretrained(PATH_MODEL)\n",
        "tokenizer.save_pretrained(PATH_MODEL)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dzEh5ZJZgcpd"
      },
      "source": [
        "Execute this cell only if you are having memory issues."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "id": "beJHHDBgla5i"
      },
      "outputs": [],
      "source": [
        "#Flush memory\n",
        "del trainer, model, tokenizer\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FBh-U1mSgb4E"
      },
      "source": [
        "Now you're going to load the original model again, but this time without quantization, so that the trained adapter can be merged into it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "id": "h7cIvxcTfBC4"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e3e0d23688334d0e8404252351258c92",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "base_model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    return_dict=True,\n",
        "    torch_dtype=torch.bfloat16,\n",
        ")\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name,\n",
        "                                          use_fast=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rIIv9Higg_8H"
      },
      "source": [
        "Next, merge the original model with the saved LoRA adapter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "id": "W4yTSZmQM94u"
      },
      "outputs": [],
      "source": [
        "model = PeftModel.from_pretrained(base_model, PATH_MODEL)\n",
        "model = model.merge_and_unload()"
      ]
    },
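    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, merging folds each low-rank update into its base weight matrix: `W_merged = W + (lora_alpha / r) * (B @ A)`. The following cell illustrates that arithmetic in plain Python on tiny, made-up matrices. The helper names `matmul` and `merge_lora` are hypothetical, and this is a sketch of the idea under those assumptions, not the PEFT implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def matmul(a, b):\n",
        "    # Multiply two matrices stored as lists of rows.\n",
        "    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]\n",
        "\n",
        "def merge_lora(weight, lora_a, lora_b, lora_alpha, r):\n",
        "    # Return weight + (lora_alpha / r) * (lora_b @ lora_a).\n",
        "    scaling = lora_alpha / r\n",
        "    delta = matmul(lora_b, lora_a)\n",
        "    return [[w + scaling * d for w, d in zip(w_row, d_row)]\n",
        "            for w_row, d_row in zip(weight, delta)]\n",
        "\n",
        "# Toy values, hypothetical and unrelated to the real Gemma weights.\n",
        "W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (2 x 2)\n",
        "A = [[1.0, 2.0]]              # LoRA A matrix (r x in_features), with r = 1\n",
        "B = [[0.5], [0.0]]            # LoRA B matrix (out_features x r)\n",
        "\n",
        "merged = merge_lora(W, A, B, lora_alpha=2, r=1)\n",
        "print(merged)"
      ]
    },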
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0pXaNygrJvgW"
      },
      "source": [
        "The model that you have in memory is now a combination of the base model and the adapter that you have trained. You can now save this new model and upload it to Hugging Face."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "id": "wUiAMpplNIft"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "('test_dpo_gemma_b/tokenizer_config.json',\n",
              " 'test_dpo_gemma_b/special_tokens_map.json',\n",
              " 'test_dpo_gemma_b/tokenizer.model',\n",
              " 'test_dpo_gemma_b/added_tokens.json')"
            ]
          },
          "execution_count": 28,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model.save_pretrained(new_model)\n",
        "tokenizer.save_pretrained(new_model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "id": "i43wf3emOt-6"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "043aeab9f87c46ebafe3a3d93653d0eb",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8b7bcad0f3214801a5e77498f055a416",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "02c81551670f424baa83a7b11aaa0ed6",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "99bd7d35365b416b9b0c31814fff8a87",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "CommitInfo(commit_url='https://huggingface.co/oopere/test_dpo_gemma_b/commit/3b2c1892ee1475f263c0dc4035fcb7a6e9330406', commit_message='Upload tokenizer', commit_description='', oid='3b2c1892ee1475f263c0dc4035fcb7a6e9330406', pr_url=None, pr_revision=None, pr_num=None)"
            ]
          },
          "execution_count": 29,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model.push_to_hub(new_model,\n",
        "                  private=True,\n",
        "                  use_temp_dir=False)\n",
        "tokenizer.push_to_hub(new_model,\n",
        "                      private=True,\n",
        "                      use_temp_dir=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G6EFsmS4UOgV"
      },
      "source": [
        "## Inference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZfA5T3K-J71j"
      },
      "source": [
        "Let's test the new model and compare it with the original."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "KkgOGkr_lqnl"
      },
      "outputs": [],
      "source": [
        "#Original Gemma Model.\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "id": "gNl8jsclLAN8"
      },
      "outputs": [],
      "source": [
        "# Format prompt\n",
        "message = [\n",
        "    {\"role\": \"user\", \"content\": \"Solve 25000/2 step by step. \\nLimit your response to mathematical expressions and symbols.\"}\n",
        "]\n",
        "prompt = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "id": "DqAXPBGI-Yfq"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c1c23e4a1c544dc3ae26d84582de4a73",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Create pipeline\n",
        "pipeline = transformers.pipeline(\n",
        "    \"text-generation\",\n",
        "    device=\"cuda\",\n",
        "    model=model_name,\n",
        "    tokenizer=tokenizer\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "id": "SKxloN7e_wTu"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<bos><start_of_turn>user\n",
            "Solve 25000/2 step by step. \n",
            "Limit your response to mathematical expressions and symbols.<end_of_turn>\n",
            "<start_of_turn>model\n",
            "25000 / 2 = 12500 \n",
            "\n",
            "Here's how we get there:\n",
            "\n",
            "* **Division:**  The division symbol (/) means we are splitting a number into equal parts.\n",
            "* **25000:** This is the dividend (the number being divided).\n",
            "* **2:** This is the divisor (the number we are dividing by). \n",
            "* **12500:** This is the result of the division. \n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Generate text\n",
        "sequences = pipeline(\n",
        "    prompt,\n",
        "    do_sample=True,\n",
        "    temperature=0.1,\n",
        "    top_p=0.2,\n",
        "    num_return_sequences=1,\n",
        "    max_length=200,\n",
        ")\n",
        "print(sequences[0]['generated_text'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cXmY7ogudWiF"
      },
      "source": [
        "**The response obtained with the original model contains explanatory text, ignoring the instructions in the prompt.**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "id": "gPWm397LCIrD"
      },
      "outputs": [],
      "source": [
        "del pipeline, tokenizer\n",
        "#Flush memory\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "id": "6DW3zSQsVwuX"
      },
      "outputs": [],
      "source": [
        "# Load the Aligned Model.\n",
        "tokenizer_new_model = AutoTokenizer.from_pretrained(new_model)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "metadata": {
        "id": "MU3If6-fn9eM"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e345b55d5e6741a7a963fcf75f63f807",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Create pipeline\n",
        "pipeline_new = transformers.pipeline(\n",
        "    \"text-generation\",\n",
        "    device=\"cuda\",\n",
        "    model=new_model,\n",
        "    tokenizer=tokenizer_new_model\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 37,
      "metadata": {
        "id": "QNzczoVgV_uh"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<bos><start_of_turn>user\n",
            "Solve 25000/2 step by step. \n",
            "Limit your response to mathematical expressions and symbols.<end_of_turn>\n",
            "<start_of_turn>model\n",
            "25000 / 2 = 12500 \n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Generate text\n",
        "prompt = tokenizer_new_model.apply_chat_template(message, add_generation_prompt=True, tokenize=False)\n",
        "\n",
        "sequences = pipeline_new(\n",
        "    prompt,\n",
        "    do_sample=True,\n",
        "    temperature=0.1,\n",
        "    top_p=0.2,\n",
        "    num_return_sequences=1,\n",
        "    max_length=200,\n",
        ")\n",
        "print(sequences[0]['generated_text'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pPZM_NNSdnAu"
      },
      "source": [
        "**The response of the DPO-aligned model contains only numbers and mathematical symbols, as requested in the prompt.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KlzRy6R6KCdL"
      },
      "source": [
        "Perfect! The new model returns only the mathematical expression, in line with the chosen answers present in the dataset.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M3iVcWtzKm0g"
      },
      "source": [
        "## Summary\n",
        "\n",
        "On this test prompt, the model alignment process worked as intended. The truth is, with the Hugging Face libraries, the process itself is straightforward.\n",
        "\n",
        "The real challenge lies in understanding the technique, knowing when to apply it, and having the necessary data.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]Aligning_DPO.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
